{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies\n",
    "%pip install sf-hamilton[visualization]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Caching Nodes with Hamilton [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/caching_nodes/caching_graph_adapter/caching_nodes.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/caching_nodes/caching_graph_adapter/caching_nodes.ipynb)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Loaders\n",
    "# When developing, we would likely want to cache our data loaders because of latencies in loading data from networked sources or slow disks.\n",
    "# Functions marked by `tag(cache=\"SERIALIZATION_FORMAT\")` are automatically cached by the CachingGraphAdapter (discussed later).\n",
    "\n",
    "from hamilton.function_modifiers import tag\n",
    "\n",
    "spends_data = [10, 10, 20, 40, 40, 50]\n",
    "signups_data = [1, 10, 50, 100, 200, 400]\n",
    "\n",
    "@tag(cache=\"parquet\")\n",
    "def spend() -> pd.Series:\n",
    "    \"\"\"Stand-in for a potentially slow extraction of the spend column.\n",
    "\n",
    "    Wraps the notebook-level `spends_data` list in a Series every time it runs.\n",
    "    \"\"\"\n",
    "    spend_series = pd.Series(spends_data)\n",
    "    return spend_series\n",
    "\n",
    "\n",
    "@tag(cache=\"parquet\")\n",
    "def signups() -> pd.Series:\n",
    "    \"\"\"Stand-in for a potentially slow extraction of the signups column.\n",
    "\n",
    "    Wraps the notebook-level `signups_data` list in a Series every time it runs.\n",
    "    \"\"\"\n",
    "    signups_series = pd.Series(signups_data)\n",
    "    return signups_series\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions holding business logic\n",
    "\n",
    "\n",
    "def avg_3wk_spend(spend: pd.Series) -> pd.Series:\n",
    "    \"\"\"Mean of spend over a rolling window of three rows (weeks).\n",
    "\n",
    "    The first two entries are NaN because their window is not yet full.\n",
    "    \"\"\"\n",
    "    three_week_window = spend.rolling(window=3)\n",
    "    return three_week_window.mean()\n",
    "\n",
    "\n",
    "def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:\n",
    "    \"\"\"Element-wise ratio of spend to signups, i.e. the cost of each signup.\"\"\"\n",
    "    cost_per_signup = spend / signups\n",
    "    return cost_per_signup\n",
    "\n",
    "\n",
    "def spend_mean(spend: pd.Series) -> float:\n",
    "    \"\"\"Example of a node that yields a scalar: the mean over the whole spend column.\"\"\"\n",
    "    column_average = spend.mean()\n",
    "    return column_average\n",
    "\n",
    "\n",
    "def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:\n",
    "    \"\"\"Example of a node that consumes a scalar: spend shifted by its mean so it centers on zero.\"\"\"\n",
    "    centered_spend = spend - spend_mean\n",
    "    return centered_spend\n",
    "\n",
    "\n",
    "def spend_std_dev(spend: pd.Series) -> float:\n",
    "    \"\"\"Standard deviation of the spend column, as computed by `Series.std()`.\"\"\"\n",
    "    column_spread = spend.std()\n",
    "    return column_spread\n",
    "\n",
    "\n",
    "def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:\n",
    "    \"\"\"One way to standardize spend: divide the mean-centered column by its standard deviation.\"\"\"\n",
    "    standardized_spend = spend_zero_mean / spend_std_dev\n",
    "    return standardized_spend\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the functions into a temporary module -- the idea is that this should house a curated set of functions.\n",
    "# Don't be afraid to make multiple of them -- however we'd advise you to not use this method for production.\n",
    "# Also note, that using a temporary function module does not work for scaling onto Ray, Dask, or Pandas on Spark.\n",
    "from hamilton import ad_hoc_utils\n",
    "\n",
    "\n",
    "# Module holding the two data-loading functions.\n",
    "data_loaders = ad_hoc_utils.create_temporary_module(\n",
    "    spend,\n",
    "    signups,\n",
    "    module_name=\"data_loaders\",\n",
    ")\n",
    "\n",
    "# Module holding the derived-column functions; they are gathered in a tuple\n",
    "# and unpacked as positional arguments.\n",
    "business_logic_functions = (\n",
    "    avg_3wk_spend,\n",
    "    spend_per_signup,\n",
    "    spend_mean,\n",
    "    spend_zero_mean,\n",
    "    spend_std_dev,\n",
    "    spend_zero_mean_unit_variance,\n",
    ")\n",
    "business_logic = ad_hoc_utils.create_temporary_module(\n",
    "    *business_logic_functions, module_name=\"business_logic\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hamilton import base, driver\n",
    "from hamilton.experimental import h_cache\n",
    "import pathlib\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   spend  signups  avg_3wk_spend  spend_per_signup  spend_zero_mean_unit_variance\n",
      "0     10        1            NaN            10.000                      -1.064405\n",
      "1     10       10            NaN             1.000                      -1.064405\n",
      "2     20       50      13.333333             0.400                      -0.483821\n",
      "3     40      100      23.333333             0.400                       0.677349\n",
      "4     40      200      33.333333             0.200                       0.677349\n",
      "5     50      400      43.333333             0.125                       1.257934\n"
     ]
    }
   ],
   "source": [
    "# No inputs are passed in directly: the source columns come from the data_loaders module.\n",
    "initial_columns = {}\n",
    "\n",
    "# Columns requested from the dataflow.\n",
    "output_columns = [\n",
    "    \"spend\",\n",
    "    \"signups\",\n",
    "    \"avg_3wk_spend\",\n",
    "    \"spend_per_signup\",\n",
    "    \"spend_zero_mean_unit_variance\",\n",
    "]\n",
    "\n",
    "# Make sure the cache directory exists before the adapter uses it.\n",
    "cache_path = \"tmp\"\n",
    "cache_dir = pathlib.Path(cache_path)\n",
    "cache_dir.mkdir(exist_ok=True)\n",
    "\n",
    "# The adapter is given the cache directory and a DataFrame result builder.\n",
    "adapter = h_cache.CachingGraphAdapter(cache_path, base.PandasDataFrameResult())\n",
    "dr = driver.Driver(initial_columns, business_logic, data_loaders, adapter=adapter)\n",
    "\n",
    "df = dr.execute(output_columns)\n",
    "print(df.to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's change the source values for our data loaders.\n",
    "# Each list is rescaled from its own previous values, so signups stay\n",
    "# independent of spend.\n",
    "\n",
    "spends_data = [i * 1000 for i in spends_data]\n",
    "signups_data = [i * 1000 for i in signups_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   spend  signups  avg_3wk_spend  spend_per_signup  spend_zero_mean_unit_variance\n",
      "0     10        1            NaN            10.000                      -1.064405\n",
      "1     10       10            NaN             1.000                      -1.064405\n",
      "2     20       50      13.333333             0.400                      -0.483821\n",
      "3     40      100      23.333333             0.400                       0.677349\n",
      "4     40      200      33.333333             0.200                       0.677349\n",
      "5     50      400      43.333333             0.125                       1.257934\n"
     ]
    }
   ],
   "source": [
    "# Since the data loaders are cached, they should continue returning the old values,\n",
    "# even though the previous cell rescaled `spends_data` and `signups_data`.\n",
    "# The source lists are not rescaled again here, so a top-to-bottom run of the\n",
    "# notebook changes them exactly once.\n",
    "\n",
    "# CachingGraphAdapter handles the actual caching during execution.\n",
    "adapter = h_cache.CachingGraphAdapter(cache_path, base.PandasDataFrameResult())\n",
    "\n",
    "# Hamilton caches are valid across new instances of the driver.\n",
    "dr = driver.Driver(initial_columns, business_logic, data_loaders, adapter=adapter)\n",
    "output_columns = [\n",
    "    \"spend\",\n",
    "    \"signups\",\n",
    "    \"avg_3wk_spend\",\n",
    "    \"spend_per_signup\",\n",
    "    \"spend_zero_mean_unit_variance\",\n",
    "]\n",
    "\n",
    "df = dr.execute(output_columns)\n",
    "print(df.to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   spend   signups  avg_3wk_spend  spend_per_signup  spend_zero_mean_unit_variance\n",
      "0  10000  10000000            NaN             0.001                      -1.064405\n",
      "1  10000  10000000            NaN             0.001                      -1.064405\n",
      "2  20000  20000000   13333.333333             0.001                      -0.483821\n",
      "3  40000  40000000   23333.333333             0.001                       0.677349\n",
      "4  40000  40000000   33333.333333             0.001                       0.677349\n",
      "5  50000  50000000   43333.333333             0.001                       1.257934\n"
     ]
    }
   ],
   "source": [
    "# Now let's force Hamilton to recompute the cached data loaders.\n",
    "\n",
    "# Node names passed as `force_compute`.\n",
    "nodes_to_recompute = {\"spend\", \"signups\"}\n",
    "adapter = h_cache.CachingGraphAdapter(\n",
    "    cache_path,\n",
    "    base.PandasDataFrameResult(),\n",
    "    force_compute=nodes_to_recompute,\n",
    ")\n",
    "dr = driver.Driver(initial_columns, business_logic, data_loaders, adapter=adapter)\n",
    "\n",
    "output_columns = [\n",
    "    \"spend\",\n",
    "    \"signups\",\n",
    "    \"avg_3wk_spend\",\n",
    "    \"spend_per_signup\",\n",
    "    \"spend_zero_mean_unit_variance\",\n",
    "]\n",
    "df = dr.execute(output_columns)\n",
    "print(df.to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hamilton-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
